In [1]:
import os
import random
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
In [2]:
data_path = "/home/sattvik/ood_cv/data"
In [25]:
class Dataset:
    def __init__(self, data_path):
        self.data_path = data_path
        print(data_path)
    def show_image(self, image_name, title):
        image = mpimg.imread(image_name)
        plt.figure()
        plt.title(title)
        plt.imshow(image)
    def show_images(self, image_names, images_path, title):
        for image_name in image_names:
            self.show_image(os.path.join(images_path, image_name), title)
    def show_random_partition_class(self, partition, data_class, n, val_type='context'):
        if partition == 'train':
            images_path = os.path.join(self.data_path, 'ROBIN-cls-{}'.format(partition), data_class)
            image_names = os.listdir(images_path)
            if n<len(image_names):
                self.show_images(random.sample(image_names, n), images_path, data_class)
            else:
                self.show_images(image_names, images_path, data_class)
        elif partition == 'val':
            images_path = os.path.join(self.data_path, 'ROBIN-cls-{}'.format(partition), val_type, data_class)
            image_names = os.listdir(images_path)
            if n<len(image_names):
                self.show_images(random.sample(image_names, n), images_path, data_class+'_'+val_type)
            else:
                self.show_images(image_names, images_path, data_class+'_'+val_type)
                
    def visualize_train_classes(self, n):
        images_path = os.path.join(self.data_path, 'ROBIN-cls-train')
        class_names = os.listdir(images_path)
        print(class_names)
        for class_name in class_names:
            if class_name.startswith('.'):
                continue
            self.show_random_partition_class('train', class_name, n)
            
    def visualize_val_classes(self, n):
        types_path = os.path.join(self.data_path, 'ROBIN-cls-val')
        type_names = os.listdir(types_path)
        for type_name in type_names:
            if type_name.startswith('.'):
                continue
            images_path = os.path.join(self.data_path, 'ROBIN-cls-val', type_name)
            class_names = os.listdir(images_path)
            for class_name in class_names:
                if class_name.startswith('.'):
                    continue
                self.show_random_partition_class('val', class_name, n, type_name)
In [26]:
dataset = Dataset(data_path)
/home/sattvik/ood_cv/data
In [28]:
dataset.visualize_train_classes(5)
['car', 'diningtable', 'boat', 'bus', 'bicycle', '.DS_Store', '.ipynb_checkpoints', 'motorbike', 'aeroplane', 'train', 'sofa', 'chair']
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  import sys
In [29]:
dataset.visualize_val_classes(5)
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  import sys
In [ ]: